package gov.cms.grouper.snf.util;

import gov.cms.grouper.snf.SnfTables;
import gov.cms.grouper.snf.component.r2.logic.nursing.Const;
import gov.cms.grouper.snf.lego.SnfCache;
import gov.cms.grouper.snf.lego.SnfCache.SupplierCache;
import gov.cms.grouper.snf.lego.SnfComparator;
import gov.cms.grouper.snf.lego.SnfUtils;
import gov.cms.grouper.snf.lego.TriFunction;
import gov.cms.grouper.snf.model.Assessment;
import gov.cms.grouper.snf.model.SnfDiagnosisCode;
import gov.cms.grouper.snf.model.reader.Rai300;
import gov.cms.grouper.snf.model.table.BasicRow;
import gov.cms.grouper.snf.model.table.PerformanceRecodeRow;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * Read-only view of the assessments attached to one claim.
 *
 * <p>Assessments are indexed by {@link Assessment#getName()} and by {@link Assessment#getItem()};
 * item lookups use {@code Rai300.name()} as the key. The class also provides the predicates and
 * score calculations that payment components evaluate against those assessments.
 */
public class ClaimInfo {

  /** Version passed to versioned table lookups and to functional-assessment functions. */
  private final int dataVersion;
  /** Assessments keyed by name; unmodifiable. */
  private final Map<String, Assessment> byName;
  /** Assessments keyed by item ({@code Rai300.name()}); unmodifiable. */
  private final Map<String, Assessment> byItem;
  /** IPA flag; selects which item sets from {@code Const} are used. */
  private final boolean hasIpa;
  /** Function-score supplier wrapped by {@link SnfCache}; read via {@link #getFunctionScore()}. */
  private final SupplierCache<Integer> functionScoreCache;

  /**
   * Builds the name and item indexes for the given assessments.
   *
   * <p>If several assessments share a name (or an item), the one appearing later in
   * {@code assessments} replaces the earlier one in the corresponding index.
   *
   * @param dataVersion version used for versioned table lookups and passed to the
   *     functional-assessment function
   * @param hasIpa IPA flag; selects which bed-mobility, transfer and eating/toileting item sets
   *     are used
   * @param assessments the claim's assessments; must not be {@code null} nor contain {@code null}
   */
  public ClaimInfo(int dataVersion, boolean hasIpa, List<Assessment> assessments) {
    this.dataVersion = dataVersion;
    this.hasIpa = hasIpa;
    Map<String, Assessment> namesIndex = new HashMap<>();
    Map<String, Assessment> itemsIndex = new HashMap<>();
    for (Assessment ast : assessments) {
      namesIndex.put(ast.getName(), ast);
      itemsIndex.put(ast.getItem(), ast);
    }
    this.byName = Collections.unmodifiableMap(namesIndex);
    this.byItem = Collections.unmodifiableMap(itemsIndex);

    // The computation is handed to SnfCache instead of being run here; presumably it is
    // evaluated on first get() -- NOTE(review): confirm SnfCache.of is lazy/memoizing.
    this.functionScoreCache = SnfCache.of(() -> {
      Set<Rai300> bedMobility = Const.bedMobility.get(this.hasIpa());
      Set<Rai300> transfer = Const.transfer.get(this.hasIpa());
      Set<Rai300> eatToilet = Const.eatToilet.get(this.hasIpa());
      // Walking items are intentionally passed as an empty set for this score.
      return this.calculateFunctionScore(Const.assessmentValueFunction, bedMobility, transfer,
          Collections.emptySet(), eatToilet);
    });
  }

  /**
   * Static factory equivalent to {@link #ClaimInfo(int, boolean, List)}.
   *
   * @param version data version
   * @param ipa IPA flag
   * @param assessments the claim's assessments
   * @return a new {@code ClaimInfo}
   */
  public static ClaimInfo of(int version, boolean ipa, List<Assessment> assessments) {
    return new ClaimInfo(version, ipa, assessments);
  }

  /** @return the IPA flag supplied at construction */
  public boolean hasIpa() {
    return hasIpa;
  }

  /** @return the data version supplied at construction */
  public int getDataVersion() {
    return this.dataVersion;
  }

  /**
   * Returns the integer value of the assessment recorded for {@code field}.
   *
   * <p>This is a single map lookup keyed by {@code field.name()}.
   *
   * @param field the item to look up
   * @param nullValue value returned when the claim has no assessment for {@code field}
   * @return the assessment's {@code getValueInt()}, or {@code nullValue} if not found
   */
  public int getAssessmentValue(Rai300 field, int nullValue) {
    Assessment foundAssessment = this.byItem.get(field.name());
    return SnfUtils.nullCheck2(foundAssessment, () -> nullValue,
        () -> foundAssessment.getValueInt());
  }

  /**
   * Returns the integer value of the assessment recorded for {@code field}.
   *
   * @param field the item to look up
   * @return the assessment's {@code getValueInt()}, or {@code Assessment.NULL_VALUE} if not found
   */
  public int getAssessmentValue(Rai300 field) {
    return this.getAssessmentValue(field, Assessment.NULL_VALUE);
  }

  /**
   * Returns the assessment recorded for {@code field}.
   *
   * <p>If the constructor received several assessments with the same item, the last of them is
   * returned.
   *
   * @param field the item to look up
   * @return the matching assessment, or {@code null} if the claim has none for {@code field}
   */
  public Assessment getAssessment(Rai300 field) {
    return this.byItem.get(field.name());
  }

  /**
   * Checks whether at least one of the given items has a checked assessment.
   *
   * @param rai300s items to test
   * @return {@code true} if any item has an assessment whose {@code isCheck()} is true
   *     (presumably an assessment value of 1 -- confirm against {@code Assessment#isCheck})
   */
  public boolean isAnyAssessmentValuesPresent(Set<Rai300> rai300s) {
    return rai300s.stream().map(Rai300::name).map(this.byItem::get)
        .anyMatch(ast -> ast != null && ast.isCheck());
  }

  /**
   * Checks whether at least one of the given items has an assessment value greater than
   * {@code n}.
   *
   * @param rai300s items to test
   * @param n exclusive lower bound
   * @return {@code true} if any item has an assessment whose {@code getValueInt()} exceeds
   *     {@code n}; items without an assessment never match
   */
  public boolean isAnyAssessmentValuesGreaterThanN(Set<Rai300> rai300s, int n) {
    return rai300s.stream().map(Rai300::name).map(this.byItem::get)
        .anyMatch(ast -> ast != null && ast.getValueInt() > n);
  }

  /**
   * Checks whether a single item has an assessment and that assessment is checked.
   *
   * @param rai300 the item to test
   * @return {@code true} if the claim has an assessment for {@code rai300} whose
   *     {@code isCheck()} is true; {@code false} otherwise (never {@code null})
   */
  public Boolean isCheckedAndNotNull(Rai300 rai300) {
    return this.isAnyAssessmentValuesPresent(SnfUtils.toSet(rai300));
  }

  /**
   * Counts how many of the given items have a checked assessment.
   *
   * @param rai300s items to test
   * @return the number of items with an assessment whose {@code isCheck()} is true
   */
  public int countAssessmentPresent(Set<Rai300> rai300s) {
    return (int) rai300s.stream().map(Rai300::name).map(this.byItem::get)
        .filter(ast -> ast != null && ast.isCheck()).count();
  }

  /**
   * Collects the distinct NTA categories of the given diagnosis codes.
   *
   * @param codes diagnosis codes
   * @return the set of {@code getNtaCategory()} values of {@code codes}
   */
  public Set<String> getNtaCategories(List<SnfDiagnosisCode> codes) {
    return codes.stream().map(SnfDiagnosisCode::getNtaCategory).collect(Collectors.toSet());
  }

  /**
   * Tests the assessment recorded for {@code field} against {@code condition}.
   *
   * @param field the item to look up
   * @param condition predicate applied to the assessment when one exists
   * @return {@code false} if the claim has no assessment for {@code field}; otherwise the result
   *     of {@code condition}
   */
  public boolean hasAssessmentOf(Rai300 field, Predicate<Assessment> condition) {
    Assessment ast = this.byItem.get(field.name());
    return SnfUtils.nullCheck2(ast, () -> false, () -> condition.test(ast));
  }

  /**
   * Determines the resident's cognitive status from the staff assessment rather than from the
   * resident interview.
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=702" class="req">Step3</a>
   *
   * <p>The result is {@code true} when any of the following holds:
   * <ul>
   *   <li>the first two of (B0700 &gt; 0, C0700 == 1, C1000 &gt; 0) are met according to
   *       {@code SnfUtils.meetFirstNConditions} (presumably "at least two of the three" -- confirm),
   *       and B0700 &gt;= 2 or C1000 &gt;= 2;</li>
   *   <li>C1000 == 3;</li>
   *   <li>{@code isComaAndNoActivities} returns {@code true}.</li>
   * </ul>
   *
   * @param b0700Supplier supplies the B0700 value (always evaluated)
   * @param c0700Supplier supplies the C0700 value (always evaluated)
   * @param c1000Supplier supplies the C1000 value (always evaluated)
   * @param isComaAndNoActivities evaluated only if the earlier conditions are all false
   * @return whether one of the conditions above holds
   */
  public boolean isClassifiedBehavioralSymptomsCognitivePerformance(Supplier<Integer> b0700Supplier,
      Supplier<Integer> c0700Supplier, Supplier<Integer> c1000Supplier,
      Supplier<Boolean> isComaAndNoActivities) {
    int b0700 = b0700Supplier.get();
    int c0700 = c0700Supplier.get();
    int c1000 = c1000Supplier.get();

    return (SnfUtils.meetFirstNConditions(
        Arrays.asList(SnfUtils.of(b0700 > 0), SnfUtils.of(c0700 == 1), SnfUtils.of(c1000 > 0)), 2)
        && (b0700 >= 2 || c1000 >= 2)) || c1000 == 3 || isComaAndNoActivities.get();
  }

  /**
   * Converts an admission performance score to a function score for PT/OT and Nursing, using
   * the performance-recode table (Performance_Recode.csv) for this claim's data version.
   *
   * @param score supplies the performance score to look up
   * @return the table's function score, or 0 if the table has no matching row
   */
  public int performanceRecode(Supplier<Integer> score) {
    PerformanceRecodeRow row = SnfTables.get(SnfTables.performanceRecodeTable, score.get(),
        BasicRow.getVersionSelector(), dataVersion);
    return row != null ? row.getFunctionScore() : 0;
  }

  /**
   * Returns the function score computed from the IPA-dependent bed-mobility, transfer and
   * eating/toileting item sets using {@code Const.assessmentValueFunction}. No walking items are
   * included.
   *
   * @return the function score held by the cache built in the constructor
   */
  public int getFunctionScore() {
    return this.functionScoreCache.get();
  }

  /**
   * Calculates the function score from item sets, first dropping every item that has no
   * assessment or whose value is {@code Assessment.NULL_VALUE}.
   *
   * @param functionalAssessments maps (data version, this claim, item name) to an item score
   * @param bedMobilityList bed-mobility items
   * @param transferList transfer items
   * @param walkingList walking items
   * @param generalItemList items added to the total without averaging
   * @return the result of {@link #calculateFunctionScoreString} for the remaining item names
   */
  protected int calculateFunctionScore(
      TriFunction<Integer, ClaimInfo, String, Integer> functionalAssessments,
      Set<Rai300> bedMobilityList, Set<Rai300> transferList, Set<Rai300> walkingList,
      Set<Rai300> generalItemList) {
    Predicate<Rai300> isRecorded = (rai) -> {
      Assessment ast = this.getAssessment(rai);
      return ast != null && ast.getValueInt() != Assessment.NULL_VALUE;
    };

    return calculateFunctionScoreString(functionalAssessments,
        recordedNames(bedMobilityList, isRecorded), recordedNames(transferList, isRecorded),
        recordedNames(walkingList, isRecorded), recordedNames(generalItemList, isRecorded));
  }

  /** Returns the names of the items in {@code items} that satisfy {@code keep}. */
  private static List<String> recordedNames(Set<Rai300> items, Predicate<Rai300> keep) {
    return items.stream().filter(keep).map(Rai300::name).collect(Collectors.toList());
  }

  /**
   * Calculates a function score from item-name lists. Different payment components use
   * different items.
   *
   * <p>Bed-mobility scores are summed and divided by 2, transfer scores by 3 and walking scores
   * by 2. Each quotient is rounded half-up to 3 decimal places. These divisors are fixed and do
   * not depend on how many names are supplied, so a missing item effectively contributes 0 to its
   * group's average. General-item scores are added unaveraged. The total is rounded half-up to
   * an integer.
   *
   * @param functionalAssessments maps (data version, this claim, item name) to an item score
   * @param bedMobilityList bed-mobility item names
   * @param transferList transfer item names
   * @param walkingList walking item names
   * @param generalItemList item names added to the total without averaging
   * @return the total function score
   */
  public int calculateFunctionScoreString(
      TriFunction<Integer, ClaimInfo, String, Integer> functionalAssessments,
      List<String> bedMobilityList, List<String> transferList, List<String> walkingList,
      List<String> generalItemList) {
    final int scale = 3;

    BigDecimal avgBedMobility =
        averageOver(scoreItems(functionalAssessments, bedMobilityList), 2, scale);
    BigDecimal avgTransfer = averageOver(scoreItems(functionalAssessments, transferList), 3, scale);
    BigDecimal avgWalking = averageOver(scoreItems(functionalAssessments, walkingList), 2, scale);
    List<Integer> generalScores = scoreItems(functionalAssessments, generalItemList);

    BigDecimal total = SnfComparator.sum(generalScores).add(avgBedMobility).add(avgTransfer)
        .add(avgWalking);
    return total.setScale(0, RoundingMode.HALF_UP).intValue();
  }

  /** Applies {@code functionalAssessments} to each item name, preserving list order. */
  private List<Integer> scoreItems(
      TriFunction<Integer, ClaimInfo, String, Integer> functionalAssessments, List<String> items) {
    return items.stream().map(item -> functionalAssessments.apply(dataVersion, this, item))
        .collect(Collectors.toList());
  }

  /** Sums {@code scores} and divides by {@code divisor}, rounding half-up to {@code scale}. */
  private static BigDecimal averageOver(List<Integer> scores, int divisor, int scale) {
    return SnfComparator.sum(scores).divide(new BigDecimal(divisor), scale, RoundingMode.HALF_UP);
  }

  /**
   * Checks whether B0100 (coma) is 1 and every activity item is coded 1, 9 or 88.
   *
   * <p>The activity items are the union of the IPA-dependent bed-mobility, transfer and
   * eating/toileting sets: (GG0130A1, GG0130C1, GG0170B1, GG0170C1, GG0170D1, GG0170E1,
   * GG0170F1), or (GG0130A5, GG0130C5, GG0170B5, GG0170C5, GG0170D5, GG0170E5, GG0170F5) when
   * IPA.
   *
   * @param b0100Supplier supplies the B0100 value
   * @return {@code true} if B0100 is 1 and no activity item has a value other than 1, 9 or 88
   */
  public boolean isComaAndNoActivities(Supplier<Integer> b0100Supplier) {
    final int b0100 = b0100Supplier.get();
    if (b0100 != 1) {
      // The activity items cannot change the result, so skip building and scanning them.
      return false;
    }

    Set<Rai300> activitySet = new HashSet<>(Const.bedMobility.get(this.hasIpa()));
    activitySet.addAll(Const.transfer.get(this.hasIpa()));
    activitySet.addAll(Const.eatToilet.get(this.hasIpa()));

    final List<Integer> activitiesCheck = Arrays.asList(1, 9, 88);

    // Every activity must be coded 1, 9 or 88. A missing item reads as Assessment.NULL_VALUE,
    // which fails this check unless NULL_VALUE equals one of those codes.
    return activitySet.stream()
        .allMatch(item -> activitiesCheck.contains(this.getAssessmentValue(item)));
  }

  /**
   * Converts items to their names.
   *
   * @param items items to convert
   * @return {@code Rai300.name()} of each item, in iteration order
   */
  public static List<String> getString(Collection<Rai300> items) {
    return items.stream().map(Rai300::name).collect(Collectors.toList());
  }

  /** @return the names of all indexed assessments */
  public Set<String> getAssessmentNames() {
    return getAssessmentNames(null);
  }

  /**
   * Returns the names of the indexed assessments that satisfy {@code checkedAssessment}.
   *
   * @param checkedAssessment filter to apply; {@code null} accepts every assessment
   * @return names of the matching assessments from the by-name index
   */
  public Set<String> getAssessmentNames(Predicate<Assessment> checkedAssessment) {
    Predicate<Assessment> filter = checkedAssessment == null ? ast -> true : checkedAssessment;
    return this.byName.values().stream().filter(filter).map(Assessment::getName)
        .collect(Collectors.toSet());
  }

  /** @return the assessments held in the by-item index */
  public Set<Assessment> getAssessments() {
    return SnfUtils.toSet(this.byItem.values());
  }

}
